import matplotlib.pyplot as pyplot
import numpy
import svgutils as SVG
from matplotlib.colors import LinearSegmentedColormap,ListedColormap
from matplotlib import rcParams, lines
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap, BoundaryNorm,LogNorm
from matplotlib.ticker import MultipleLocator, ScalarFormatter,FuncFormatter,FormatStrFormatter
from scipy.optimize import fsolve as fsolve
import scipy.optimize as optimise
from matplotlib import gridspec
import scipy.interpolate as interp
from matplotlib.patches import ConnectionPatch,FancyBboxPatch,Rectangle
import os
from jqc import jqc_plot
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import scipy.constants as constants
from sympy.physics.wigner import wigner_3j,wigner_9j

# Physical constants and optical parameters shared by the whole script.
h = constants.h      # Planck constant in J s, taken from scipy.constants
pi = numpy.pi
Wavelength = 1064    # wavelength in nm (not used elsewhere in the visible code)

# Power transmissions of the cell for s- and p-polarised light, followed by
# the matching amplitude transmissions t = sqrt(T).
Ts, Tp = 0.781, 0.99468
ts, tp = (numpy.sqrt(transmission) for transmission in (Ts, Tp))

# Apply the JQC "normal" plotting style before any figure is created.
jqc_plot.plot_style("normal")
# Absolute directory containing this script; data and output paths hang off it.
cwd = os.path.split(os.path.abspath(__file__))[0]


def _twk_channel(stops):
    '''
    Expand (position, 0-255 level) pairs into LinearSegmentedColormap channel
    rows of the form (position, level/255, level/255), i.e. no jumps.
    '''
    return [(position, level/255.0, level/255.0) for position, level in stops]


# Blue ramp: starts at RGB (244, 234, 168) and ends at RGB (0, 32, 58).
colour_dict_twk_blue = {
    "red":   _twk_channel([(0.0, 244), (0.33, 124), (0.66, 0), (1.0, 0)]),
    "green": _twk_channel([(0.0, 234), (0.33, 154), (0.66, 70), (1.0, 32)]),
    "blue":  _twk_channel([(0.0, 168), (0.33, 148), (0.66, 127), (1.0, 58)]),
}

# Red ramp: starts at RGB (244, 234, 168) and ends at RGB (170, 43, 74).
colour_dict_twk_red = {
    "red":   _twk_channel([(0.0, 244), (0.33, 229), (0.66, 214), (1.0, 170)]),
    "green": _twk_channel([(0.0, 234), (0.33, 177), (0.66, 120), (1.0, 43)]),
    "blue":  _twk_channel([(0.0, 168), (0.33, 145), (0.66, 122), (1.0, 74)]),
}

# Named JQC colours as RGB fractions in [0, 1], converted from 0-255 triples.
JQC = {name: tuple(channel/255.0 for channel in rgb) for name, rgb in (
    ('red',      (198, 62, 98)),
    ('blue',     (0, 70, 127)),
    ('purple',   (126, 29, 123)),
    ('sand',     (244, 234, 168)),
    ('grayblue', (212, 213, 220)),
    ('green',    (45, 159, 60)),
)}

# Alpha channel shared by both tweaked colormaps: transparent at 0, opaque
# from the midpoint upwards (linear in between), so the low end of the colour
# scale fades out.
_TWK_ALPHA = ((0.0, 0.0, 0.0),
              (0.5, 1.0, 1.0),
              (1.0, 1.0, 1.0))


def _register_colormap(cmap):
    '''
    Register cmap under cmap.name so plotting calls can request it by string.

    matplotlib.colormaps.register (Matplotlib >= 3.6) is used when available.
    pyplot.register_cmap is deprecated since 3.7 and removed in 3.9, so it is
    only the fallback for older releases. force=True overwrites an earlier
    registration of the same (non-builtin) name, so the script can be re-run
    within one interpreter session.
    '''
    import matplotlib
    register = getattr(getattr(matplotlib, "colormaps", None), "register", None)
    if register is not None:
        register(cmap, force=True)
    else:
        pyplot.register_cmap(cmap=cmap)


# Copies of the RGB tables with the alpha channel added; the originals are
# left unmodified.
colour_dict_twk_blue_alpha = dict(colour_dict_twk_blue, alpha=_TWK_ALPHA)
RbCs_map_twk_blue = LinearSegmentedColormap("RbCs_map_tweak_blue",
                                            colour_dict_twk_blue_alpha)
_register_colormap(RbCs_map_twk_blue)

colour_dict_twk_red_alpha = dict(colour_dict_twk_red, alpha=_TWK_ALPHA)
RbCs_map_twk_red = LinearSegmentedColormap("RbCs_map_tweak_red",
                                           colour_dict_twk_red_alpha)
_register_colormap(RbCs_map_twk_red)


def make_segments(x, y):
    '''
    Pair consecutive points of (x, y) into two-point line segments.

    Returns an array of shape (len(x) - 1, 2, 2) whose entry k is
    [[x[k], y[k]], [x[k+1], y[k+1]]], i.e. the
    (segments x points-per-segment x xy) layout LineCollection accepts.
    '''
    xy = numpy.array([x, y]).T      # one (x, y) row per point
    starts, ends = xy[:-1], xy[1:]
    return numpy.stack([starts, ends], axis=1)

def colorline(x, y, z=None, cmap=pyplot.get_cmap('copper'),
        norm=None, linewidth=3, alpha=None, legend=False):
    '''
    Plot a line through (x, y) whose colour varies along its length, adding
    it to the current axes.

    Parameters
    ----------
    x, y : sequences of point coordinates (same length).
    z : values mapped through norm and cmap to colour the segments. Defaults
        to evenly spaced values on [0, 1], one per point. A bare number
        becomes a length-1 array.
    cmap : colormap or registered colormap name.
    norm : normalisation applied to z. None (default) creates a fresh
        Normalize(0.0, 1.0) for this call.
    linewidth : line width passed to LineCollection.
    alpha : None (default) leaves the colormap's own alpha channel in effect;
        a number overrides the transparency of every segment.
    legend : unused; accepted only for backward compatibility.

    Returns
    -------
    The LineCollection that was added to the axes.
    '''
    # Default colours: equally spaced on [0, 1], one per point.
    if z is None:
        z = numpy.linspace(0.0, 1.0, len(x))

    # A single number (no __iter__) is wrapped into a length-1 array.
    if not hasattr(z, "__iter__"):
        z = numpy.array([z])

    z = numpy.asarray(z)

    if norm is None:
        # A new instance per call: a Normalize object used as a default
        # argument would be shared by every line drawn without an explicit
        # norm, so changing its limits on one line would change them all.
        norm = pyplot.Normalize(0.0, 1.0)

    segments = make_segments(x, y)
    lc = LineCollection(segments, array=z, cmap=cmap, norm=norm,
                        linewidth=linewidth, alpha=alpha)

    ax = pyplot.gca()
    ax.add_collection(lc)

    return lc

def dipolez(Nmax,d):
    '''
    Build the matrix of the z (q = 0) component of the dipole operator in the
    rigid-rotor basis |N, M>, N = 0..Nmax.

    The element between |N1, M1> (row) and |N2, M2> (column) is
        d * sqrt((2N1+1)(2N2+1)) * (-1)**M1
          * wigner_3j(N1, 1, N2, -M1, 0, M2) * wigner_3j(N1, 1, N2, 0, 0, 0)

    Basis order: N ascending and, within each N, M = N, N-1, ..., -N, giving
    (Nmax+1)**2 basis states.

    Parameters
    ----------
    Nmax : int, highest rotational quantum number included.
    d : scale factor multiplying every element.

    Returns
    -------
    Complex numpy array of shape ((Nmax+1)**2, (Nmax+1)**2).
    '''
    basis = [(N, M) for N in range(0, Nmax+1) for M in range(N, -(N+1), -1)]
    # dtype=complex instead of numpy.complex: that alias of the builtin was
    # removed in NumPy 1.24. Both give a complex128 array.
    Dmat = numpy.zeros((len(basis), len(basis)), dtype=complex)
    for i, (N1, M1) in enumerate(basis):
        for j, (N2, M2) in enumerate(basis):
            Dmat[i, j] = d*numpy.sqrt((2*N1+1)*(2*N2+1))*(-1)**(M1)*\
                wigner_3j(N1,1,N2,-M1,0,M2)*wigner_3j(N1,1,N2,0,0,0)
    return Dmat
############################ PLOTTING ###########################
# Figure layout: two stacked panels in column 0, a narrow spacer in column 1
# and a colour bar spanning both rows in column 2.
fig = pyplot.figure()
grid = gridspec.GridSpec(2, 3, width_ratios=[1, 0.01, 0.05])
vmin = 1e-2   # lower limit of the logarithmic colour scale
vmax = 1      # upper limit of the logarithmic colour scale
cbar_ax = fig.add_subplot(grid[:, 2])

# Colour of the faint background curves, RGB (244, 234, 168).
_SAND = (244.0/255.0, 234.0/255.0, 168.0/255.0)


def _transmitted_fraction(beta_deg):
    '''
    Fraction of the incident power transmitted through the cell for light
    polarised at beta_deg degrees. The light is split into s and p parts with
    amplitude transmissions ts and tp.
    '''
    beta = numpy.deg2rad(beta_deg)
    S_component = ts*numpy.cos(beta)
    P_component = tp*numpy.sin(beta)
    return S_component**2 + P_component**2


def _load_lines(folder):
    '''
    Read <folder>/lines.csv (transposed on load).

    Returns (intensity, lines). intensity is column 0 of the transposed data.
    lines holds the remaining columns, each minus column 1, divided by h.
    This gives Hz assuming the file stores energies in J (TODO confirm).
    '''
    data = numpy.genfromtxt(os.path.join(folder, "lines.csv"), delimiter=',').T
    intensity = data[:, 0].copy()
    # Every state is referenced to the first one (data column 1), so column 0
    # of `lines` is identically zero.
    lines = (data[:, 1:] - data[:, 1:2])/h
    return intensity, lines


def _load_transition_strengths(folder, lines_shape):
    '''
    Return |<state 0| dz |state i>| at every intensity point k as an array of
    shape lines_shape, indexed [k, i].

    The result is cached in <folder>/TDM.csv. If that file cannot be read, it
    is recomputed from <folder>/States.npy and the cache is written.
    NOTE: earlier versions of this script assigned whole columns inside the
    k loop, so every row of their TDM.csv equals the last intensity point.
    Delete a stale TDM.csv to regenerate it.
    '''
    tdm_path = os.path.join(folder, "TDM.csv")
    try:
        return numpy.genfromtxt(tdm_path, delimiter=',')
    except IOError:
        pass  # no cache yet: compute it below

    # Nuclear spins: their (2I+1)-dimensional identities extend the
    # rotational dipole matrix to the full basis.
    IRb = 3/2
    ICs = 7/2
    dz = numpy.kron(dipolez(5, 1), numpy.kron(numpy.identity(int(2*IRb+1)),
                                              numpy.identity(int(2*ICs+1))))
    # Indexed [basis component, state, intensity point] (inferred from the
    # indexing below).
    states = numpy.load(os.path.join(folder, "States.npy"))
    strengths = numpy.zeros(lines_shape)
    # NOTE(review): only states 32..96 are evaluated; all other columns stay
    # zero. Confirm this range is intended.
    for i in range(32, 3*32+1):
        for k in range(lines_shape[0]):
            # Element (k, i): each intensity point keeps its own value.
            strengths[k, i] = numpy.abs(numpy.dot(numpy.conjugate(states[:, 0, k]),
                                                  numpy.dot(dz, states[:, i, k])))
    numpy.savetxt(tdm_path, strengths, delimiter=',')
    return strengths


def _plot_panel(ax, folder, beta_deg, watts_per_volt, label, tag):
    '''
    Draw one panel onto ax. It shows the measured points from
    <folder>/Exp.csv and the calculated lines from <folder>/lines.csv,
    coloured by 3*|<0|dz|i>|**2 on a log scale.

    beta_deg is the polarisation angle in degrees. watts_per_volt is the
    calibration factor applied to column 0 of Exp.csv. label and tag are the
    texts written in the top-left and top-right corners.

    Returns the last LineCollection drawn (used for the shared colour bar),
    or None if lines.csv has no line columns.
    '''
    attn = _transmitted_fraction(beta_deg)
    intensity, lines = _load_lines(folder)
    strengths = _load_transition_strengths(folder, lines.shape)

    Exp = numpy.genfromtxt(os.path.join(folder, "Exp.csv"), delimiter=',')
    # Peak intensity 2P/(pi w^2) with w = 174e-6. Presumably w is the beam
    # waist in m, giving W/m^2 (TODO confirm).
    Exp[:, 0] = 2*(Exp[:, 0]*attn*watts_per_volt)/(pi*174e-6**2)
    ax.errorbar(1e-7*Exp[:, 0], Exp[:, 1], yerr=Exp[:, 2],
                capsize=3.5, color='k', fmt='o')

    # Unit conversions: intensity x 1e-7 takes W/m^2 to kW/cm^2, and line
    # frequencies x 1e-6 take Hz to MHz.
    x = intensity*1e-7
    y = lines*1e-6

    # Background curves beneath the coloured overlay: the reference line
    # (column 0) in black, the others in translucent sand.
    for i in range(y.shape[1]):
        if i == 0:
            ax.plot(x, y[:, i], color='k', zorder=0)
        else:
            ax.plot(x, y[:, i], color=_SAND, alpha=.6, zorder=0)

    pyplot.sca(ax)  # colorline draws on the current axes
    cl = None
    for i in range(y.shape[1]):
        cl = colorline(x, y[:, i], 3*strengths[:, i]**2,
                       cmap='RbCs_map_tweak_blue', norm=LogNorm(vmin, vmax=vmax),
                       linewidth=2.0)

    ax.text(0.02, 0.85, label, transform=ax.transAxes)
    ax.text(0.91, 0.85, tag, fontsize=20, transform=ax.transAxes)
    return cl


########### BETA = magic ################
BetaMagic = fig.add_subplot(grid[0, 0])
_plot_panel(BetaMagic, os.path.join(cwd, "Magic"), 54.7356, 0.985,
            "$\\beta_\\mathrm{magic}$", "(a)")

########### BETA = 47 ################
Beta47 = fig.add_subplot(grid[1, 0], sharex=BetaMagic, sharey=BetaMagic)
cl = _plot_panel(Beta47, os.path.join(cwd, "47"), 47, 2.12,
                 "$\\beta=47^\\circ$", "(b)")

# Every line uses the same colormap and LogNorm limits, so the last one drawn
# stands in for all of them on the colour bar. (No `pad`: it only applies
# when Matplotlib creates the colour-bar axes itself.)
cbar1 = pyplot.colorbar(cl, cax=cbar_ax)
cbar1.ax.set_title("$z$", color=jqc_plot.colours['blue'])
cbar1.ax.set_ylabel("Relative Transition Strength")

fig.text(0.05, 0.5, "Transition Frequency (MHz)", transform=fig.transFigure,
         rotation=90, verticalalignment="center")

Beta47.set_ylim(982.35, 982.428)

# Show y tick labels relative to 982 MHz. The offset is written once above
# the top panel instead of via Matplotlib's own offset text.
Offset = 982
for panel in (Beta47, BetaMagic):
    panel.ticklabel_format(axis='y', useOffset=Offset, style='plain')
    panel.yaxis.offsetText.set_visible(False)

pyplot.setp(BetaMagic.get_xticklabels(), visible=False)

BetaMagic.text(0, 1.05, "+982 MHz", transform=BetaMagic.transAxes)

Beta47.set_xlabel("Intensity (kW$\\,$cm$^{-2}$)")
BetaMagic.set_xlim(0, 12.5)

# tight_layout is run twice, as before (presumably to let the layout settle).
# subplots_adjust then fixes the margins and spacing explicitly.
pyplot.tight_layout()
pyplot.tight_layout()
pyplot.subplots_adjust(wspace=0, hspace=0.09, left=0.17, right=0.86, top=0.92)

pyplot.savefig(os.path.join(cwd, "BeyondMagic.pdf"))
pyplot.savefig(os.path.join(cwd, "BeyondMagic.png"))

pyplot.show()
